import abc
from typing import Optional, Any, List, NamedTuple, Generic, TypeVar, Tuple


# Shield result:
# real_action corresponds to the action that will actually take place in the environment.
# It should be a safe joint action, because this action will actually occur.
# However, sometimes the shield may also want to record the originally proposed action
# as an extra training transition; that is what augmented_action is for.
#
# For each agent, the shield:
# - can allow the proposed action or replace it with another one (real_action.action);
# - can punish (negative) or leave unchanged (zero). NOTE this applies to the transition which is actually taken;
# - if the action is replaced, can generate a "fake" extra transition for training (augmented_action is set)
#   or only the transition actually taken (augmented_action is None);
# - if the extra transition is created, decides how much punishment (negative) to assign to it, or leaves it unchanged (zero).


class AgentUpdate(NamedTuple):
    """An action paired with an optional adjustment of the reward it receives.

    ``absolute_reward`` takes precedence over ``reward_modifier``. When both are
    ``None`` the environment reward is passed through unchanged.
    """

    # The action this update refers to.
    action: Any
    # When set, added to the environment reward (ignored if absolute_reward is set).
    reward_modifier: Optional[int] = None
    # When set, used in place of the environment reward.
    absolute_reward: Optional[int] = None

    def get_modified_reward(self, rew_from_env):
        """Return the reward to use for this update.

        Args:
            rew_from_env: The reward reported by the environment.

        Returns:
            ``absolute_reward`` if it is set; otherwise ``reward_modifier``
            added to ``rew_from_env`` if the modifier is set; otherwise
            ``rew_from_env`` itself.
        """
        override = self.absolute_reward
        if override is not None:
            # An explicit override wins, even when it is zero.
            return override

        delta = self.reward_modifier
        if delta is None:
            return rew_from_env
        return delta + rew_from_env


class AgentResult(NamedTuple):
    """Shield outcome for one agent: the update to execute plus an optional extra one.

    ``real_action`` is the update for the action that actually takes place in the
    environment. ``augmented_action`` is optional and defaults to ``None``; when
    present it describes an additional ("fake") transition intended for training.
    """

    # Update for the action that is actually executed.
    real_action: AgentUpdate
    # Optional extra update; None means only the executed transition is produced.
    augmented_action: Optional[AgentUpdate] = None


# Result of shielding a joint action: a list of AgentResult entries
# (presumably one per agent, in agent order -- confirm against implementations).
ShieldResult = List[AgentResult]

# If there is a centralized training component, the augmented_actions of all agents
# should be treated together as a single joint action.

# T: type of the environment handed to Shield.__init__.
T = TypeVar("T")
# S: type of the shield state produced and consumed by Shield's abstract methods.
S = TypeVar("S")


class Shield(abc.ABC, Generic[T, S]):
    """Abstract base class for shields that vet proposed joint actions.

    Generic over ``T`` (the environment type) and ``S`` (the shield-state type).
    Subclasses implement ``get_initial_shield_state`` and
    ``evaluate_joint_action``; ``replace_action_agent_result`` is a shared helper
    for building the result when an action is substituted.
    """

    def __init__(self, env: T, punish_unsafe_orig_action=False, unsafe_action_punishment=0):
        """Store the environment and the punishment configuration.

        Args:
            env: The environment this shield is attached to.
            punish_unsafe_orig_action: When truthy, a replaced action also yields
                an augmented update for the original action.
            unsafe_action_punishment: Reward modifier attached to that augmented
                update.
        """
        self.env = env
        self.punish_unsafe_orig_action = punish_unsafe_orig_action
        self.unsafe_action_punishment = unsafe_action_punishment

    @abc.abstractmethod
    def get_initial_shield_state(self, state, initial_joint_obs) -> S:
        """Return the shield state to start from, given the initial state and joint observation."""

    @abc.abstractmethod
    def evaluate_joint_action(self, state, joint_obs, proposed_action, shield_state: S) -> Tuple[ShieldResult, S]:
        """Evaluate ``proposed_action`` and return ``(shield_result, shield_state)``.

        The returned shield state is presumably the one to pass to the next
        call -- confirm against concrete implementations.
        """

    def replace_action_agent_result(self, unsafe_action, safe_action) -> AgentResult:
        """Build the AgentResult for an action that is replaced by ``safe_action``.

        The real action is always ``safe_action`` with no reward adjustment. If
        ``punish_unsafe_orig_action`` is truthy, the result additionally carries
        an augmented update for ``unsafe_action`` whose reward modifier is
        ``unsafe_action_punishment``.
        """
        executed = AgentUpdate(action=safe_action)
        if not self.punish_unsafe_orig_action:
            return AgentResult(real_action=executed)

        penalised = AgentUpdate(action=unsafe_action,
                                reward_modifier=self.unsafe_action_punishment)
        return AgentResult(real_action=executed, augmented_action=penalised)
